"""
#
# Flow stability for dynamic community detection https://arxiv.org/abs/2101.06131v2
#
# Copyright (C) 2021 Alexandre Bovet <alexandre.bovet@maths.ox.ac.uk>
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.


Computes the values used to draw the schematic example of Fig. 1 of the flow stability paper (arXiv:2101.06131).

"""
import sys
import os
PACKAGE_PARENT = '..'
SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))

import numpy as np
import pandas as pd
from TemporalNetwork import ContTempNetwork
from TemporalStability import (FlowIntegralClustering,
                               static_clustering, avg_norm_var_information, Partition,
                               run_multi_louvain, norm_mutual_information)



def Round_To_n(x, n):
    """Round ``x`` to ``n + 1`` significant digits.

    The number of decimals handed to ``round`` is derived from the order of
    magnitude of ``abs(x)``, e.g. ``Round_To_n(1.23456, 2) == 1.23`` and
    ``Round_To_n(123456, 2) == 123000``.

    Parameters
    ----------
    x : float or int
        Finite value to round.
    n : int
        Number of digits kept after the leading significant digit.

    Returns
    -------
    The rounded value; ``x`` itself when ``x`` is zero.
    """
    if x == 0:
        # log10(0) is -inf: zero has no order of magnitude to round to
        return x
    # Use the magnitude of |x| only, so that negative values are rounded
    # exactly like their positive counterparts.
    magnitude = int(np.floor(np.log10(abs(x))))
    return round(x, n - magnitude)

import matplotlib.pyplot as plt
import matplotlib as mpl
try:
    plt.style.use('alex_paper')
except OSError:
    # 'alex_paper' is not one of matplotlib's bundled styles, so it has to be
    # installed by the user; fall back to the default style instead of
    # aborting the whole script when it is missing.
    import warnings
    warnings.warn("matplotlib style 'alex_paper' not found, "
                  "using the default style")
from palettable import cmocean, colorbrewer, cartocolors

import random
import igraph as ig
    
import leidenalg as la


# NOTE(review): unconditional stop placed before the first `#%%` cell --
# presumably a guard so that the file is executed cell by cell rather than
# run as a whole script; confirm before removing.
raise Exception
#%%

#%% schematic network on Fig. 1
# The four lists are aligned: position i gives the two end nodes of event i
# and its start/end times (4 nodes, 6 events).
# NOTE(review): presumably events are undirected contacts -- confirm against
# ContTempNetwork.
net = ContTempNetwork(source_nodes = [0   ,2    ,1   ,0   ,1   ,2   ],
                      target_nodes = [1   ,1    ,2   ,3   ,2   ,3   ],
                      starting_times=[0   ,0    ,1.2 ,2.2 ,4   ,4.2 ],
                      ending_times = [1.4 ,1    ,2.1   ,3.9 ,5   ,5   ])
# NOTE(review): private method; judging by its name it merges events that
# overlap in time -- confirm it is meant to be called from user code.
net._merge_overlapping_events()
#%%

# random-walk rate; swap the two lines below to switch to the fast walker
l = 1/1.6 # slow
# l = 1/0.8 # fast

# transition matrices of the forward-time walk
net.compute_inter_transition_matrices(lamda=l)
net.compute_transition_matrices(lamda=l)

# initial distribution concentrated on node 1
p0 = np.array([0, 1, 0, 0])

num_steps = len(net.times) - 1

# component 0 of p0 propagated by the k-th forward transition matrix
for k in range(num_steps):
    p_forw = p0 @ net.T[l][k].toarray()
    print(k, net.times[k + 1], f"{p_forw[0]:.2f}")

# second network built from the same events, with time-reversed transitions
net_rev = ContTempNetwork(events_table=net.events_table)
net_rev.compute_inter_transition_matrices(lamda=l)
net_rev.compute_transition_matrices(lamda=l, reverse_time=True)

# component 3 of p0 propagated by the k-th reversed transition matrix
for k in range(num_steps):
    p_back = p0 @ net_rev.T[l][k].toarray()
    print(k, net.times[-2 - k], f"{p_back[3]:.2f}")

# forward and backward flow clusterings share the same inputs
flow_args = dict(T_inter_list=net.inter_T[l], time_list=net.times)
fc = FlowIntegralClustering(**flow_args)
fc_back = FlowIntegralClustering(reverse_time=True, **flow_args)

# show the first matrix of each I_list
for clustering, label in ((fc, 'I_forw'), (fc_back, 'I_back')):
    plt.matshow(clustering.I_list[0].toarray())
    plt.title(label)

#%% cluster trans probabilities


# probs[l][i][j]: entries of the last transition matrix summed over
# (forward cluster i) x (backward cluster j), divided by the size of
# forward cluster i, for each random-walk rate l
probs = {}

# parts[l]['forw'] / parts[l]['back']: forward / backward partitions
parts = {}

for l in [1/1.6, 1/0.8]:

    net.compute_inter_transition_matrices(lamda=l)
    net.compute_transition_matrices(lamda=l)

    fc = FlowIntegralClustering(T_inter_list=net.inter_T[l],
                                time_list=net.times)

    fc_back = FlowIntegralClustering(T_inter_list=net.inter_T[l],
                                     time_list=net.times,
                                     reverse_time=True)

    fc.find_louvain_clustering()
    fc_back.find_louvain_clustering()

    forw_part = fc.partition[0].cluster_list
    back_part = fc_back.partition[0].cluster_list

    parts[l] = {'forw': forw_part, 'back': back_part}

    T = net.T[l][-1].toarray()

    # Dedicated names for the clusters: the loop variables used to be called
    # `fc`/`bc`, which overwrote the `fc` clustering object created above.
    probs[l] = {
        i: {j: T[np.ix_(list(forw_clust), list(back_clust))].sum()
               / len(forw_clust)
            for j, back_clust in enumerate(back_part)}
        for i, forw_clust in enumerate(forw_part)
    }




#%% scan

# 20 log-spaced values of tau_w between 0.01 and 100, rounded by Round_To_n
tau_ws = [Round_To_n(x, 2)
          for x in np.logspace(np.log10(0.01), np.log10(100), num=20)]

for tau_w in tau_ws:

    # random-walk rate for this value of tau_w
    l = 1/tau_w

    net.compute_inter_transition_matrices(lamda=l)
    net.compute_transition_matrices(lamda=l)

    shared_args = dict(T_inter_list=net.inter_T[l], time_list=net.times)
    fc = FlowIntegralClustering(**shared_args)
    fc_back = FlowIntegralClustering(reverse_time=True, **shared_args)

    for clustering in (fc, fc_back):
        clustering.find_louvain_clustering()

    # forward and backward partitions found for this tau_w
    print(tau_w,
          fc.partition[0].cluster_list,
          fc_back.partition[0].cluster_list)
    


#%%

# Static adjacency matrix, requested without start/end times
# (presumably aggregated over the whole time span -- confirm).
A = net.compute_static_adjacency_matrix().toarray()

sc = static_clustering(A, discrete_time_rw=True)

# NOTE(review): `_S` is a private attribute of static_clustering; it is
# plotted here under the label 'B'.
plt.matshow(sc._S)
plt.title('B all')

# fixed seed passed to the Louvain run
sc.find_louvain_clustering(rnd_seed=4)
print(sc.partition)
print(sc.compute_stability())

#%% multislice


num_slices = 4

# equally spaced slice boundaries between the start and the end of the network
time_slices = np.linspace(net.start_time, net.end_time, num_slices+1)
t_starts, t_stops = time_slices[:-1], time_slices[1:]

# one adjacency matrix per slice; a slice without any non-zero entry is
# replaced by the identity matrix
adjacencies = []
for ts, te in zip(t_starts, t_stops):
    A = net.compute_static_adjacency_matrix(start_time=ts,
                                            end_time=te).toarray()
    adjacencies.append(A if A.any() else np.eye(A.shape[0]))

# one igraph graph per slice, with edges read from the strict upper triangle
graphs = []
for A in adjacencies:
    n_rows, n_cols = A.shape
    g = ig.Graph()
    g.add_vertices(n_rows)
    g.vs['id'] = list(range(n_rows))
    for i in range(n_rows):
        for j in range(i + 1, n_cols):
            if A[i, j] > 0:
                g.add_edge(i, j, weight=A[i, j])
    graphs.append(g)

# mean over slices of the mean positive weight of each slice
avg_weight = np.mean([adj[adj > 0].mean() for adj in adjacencies])


#%%


# res[interslice_weight] collects, per resolution parameter, the NVI between
# runs, the number of clusters and the best clustering
res = {}

# number of independent optimisation runs per resolution parameter
num_repeat = 100

# source of the optimiser seeds; a single instance serves every run
sysran = random.SystemRandom()

for interslice_weight in [0,0.001,0.01,0.1,0.5,1,avg_weight]:

    res[interslice_weight] = {}

    layers, interslice_layer, G_full = \
                la.time_slices_to_layers(graphs,
                                interslice_weight=interslice_weight)

    # adjacency matrix of the full multislice graph
    A_supra = G_full.get_adjacency(attribute='weight')
    plt.matshow(np.array(A_supra.data))

    clusters_multi = {}
    stab_multi = {}

    for resolution_parameter in map(lambda x: Round_To_n(x,3),
                     np.logspace(np.log10(0.3), np.log10(35),num=15)):

        print(resolution_parameter)
        clusters_multi[resolution_parameter] = []
        stab_multi[resolution_parameter] = []

        for i in range(num_repeat):
            print(i)
            optimiser = la.Optimiser()
            # randint() needs integer bounds: float bounds such as 1e6 are
            # deprecated since Python 3.10 and raise TypeError from 3.12.
            optimiser.set_rng_seed(sysran.randint(-10**6, 10**6))

            partitions = [la.RBERVertexPartition(
                              H,
                              weights='weight',
                              resolution_parameter=resolution_parameter)
                          for H in layers]

            interslice_partition = la.RBERVertexPartition(
                                       interslice_layer,
                                       resolution_parameter=0,
                                       # node_sizes='node_size',
                                       weights='weight')

            optimiser.optimise_partition_multiplex(
                partitions + [interslice_partition])

            clusters_multi[resolution_parameter].append(
                partitions[0].membership)
            # quality of the run: summed modularity of all the layers
            stab_multi[resolution_parameter].append(
                sum(p.modularity
                    for p in partitions + [interslice_partition]))

    res_params = sorted(stab_multi.keys())
    # for each resolution, the run with the highest summed modularity
    best_clusts_multi = [clusters_multi[gamma][np.argmax(stab_multi[gamma])]
                         for gamma in res_params]

    num_clusts_multi = [max(c)+1 for c in best_clusts_multi]

    # `gamma` instead of `l`, so that the random-walk rate `l` defined by the
    # other cells is not overwritten here
    parts_multi = {}
    for gamma in res_params:
        parts_multi[gamma] = [
            Partition(num_nodes=(len(time_slices)-1)*A.shape[0],
                      node_to_cluster_dict={i:c for i,c in enumerate(clusts)}
                      ).cluster_list
            for clusts in clusters_multi[gamma]]

    # despite its name this holds the average normalised variation of
    # information (NVI) between the runs
    avg_nmi_multi = [avg_norm_var_information(parts_multi[gamma])
                     for gamma in res_params]

    res[interslice_weight]['avg_nvi_multi'] = avg_nmi_multi
    res[interslice_weight]['res_params'] = res_params
    res[interslice_weight]['num_clusts_multi'] = num_clusts_multi
    res[interslice_weight]['best_clusts_multi'] = best_clusts_multi
    
#%%
# results are shown for the largest interslice weight that was scanned
omegas = sorted(res)
omega = omegas[-1]
print(omega)

# one figure per quantity, each plotted against the resolution parameter
for res_key, label in (('avg_nvi_multi', 'NVI'),
                       ('num_clusts_multi', 'num clusts')):
    plt.figure()
    plt.plot(res_params, res[omega][res_key], 'o-')
    plt.title(f'{label} omega={omega:.3f}')
    plt.xscale('log')


#%%
# index (into res_params) of the resolution parameter to display
i = 4

n_slices = len(time_slices) - 1
n_nodes = A.shape[0]

# flat membership list reshaped to (slices, nodes), then transposed so that
# rows are nodes and columns are slices
best_membership = np.array(res[omega]['best_clusts_multi'][i])
clust_matrix = best_membership.reshape((n_slices, n_nodes)).T

plt.figure()
plt.imshow(clust_matrix)
nvi_i = res[omega]["avg_nvi_multi"][i]
plt.title(f'multislice, res={res_params[i]}, NVI={nvi_i:.2f}, omega={omega:.3f}')

# plt.savefig('../figures/schema/schema_small_rev_multislice.png')
#%% for each slice

# `_S` matrix of the static clustering of every slice
Sis = []

for i, A in enumerate(adjacencies):
    sc = static_clustering(A, discrete_time_rw=True)
    Sis.append(sc._S)

    # adjacency matrix ('A') and clustering matrix ('B') of slice i
    for matrix, tag in ((A, 'A'), (sc._S, 'B')):
        plt.matshow(matrix)
        plt.title(f'{tag} {i}')

    # 100 Louvain runs; all of them must return the same partition
    _, cluster_lists, _, _ = run_multi_louvain(sc, 100)
    assert avg_norm_var_information(cluster_lists) == 0.0
    print(i, cluster_lists[0])
    
#%% flow diagram

# for slow rw
source_comms = [{0}, {1, 2}, {3}]
target_comms = [{0, 3}, {1, 2}]

# for fast (these two assignments override the slow ones above)
source_comms = [{0, 1, 2}, {3}]
target_comms = [{0}, {1, 2, 3}]

# every node is its own class
class_dict = {node: {node} for node in range(4)}

# one flow record per (class, source community, target community) triple
# sharing at least one node; the value is the number of shared nodes
flows = []
for clas, clas_set in class_dict.items():
    for s, comm_s in enumerate(source_comms):
        in_source = clas_set & comm_s
        for t, comm_t in enumerate(target_comms):
            val = len(in_source & comm_t)
            if val == 0:
                continue
            flows.append(dict(source=int(s), target=int(t),
                              type=int(clas), value=int(val)))

df_flows = pd.DataFrame.from_dict(flows)

# start from 1
id_columns = ['source', 'target', 'type']
df_flows[id_columns] = df_flows[id_columns] + 1

import floweaver as flo

# one colour per flow class, paired with class_list by position in `palette`
color_list =["#613D94",
"#D0730F",
"#359455",
"#CB5B5A",
]





# distinct source / target ids of the flows table, in order of appearance
sources_list = df_flows.source.unique().tolist()

sources_list = np.array(sources_list)

targets_list = df_flows.target.unique().tolist()

targets_list = np.array(targets_list)

class_list = sorted(df_flows.type.unique().tolist())


# NOTE(review): source and target communities use the same integer ids,
# so both process groups may list the same values -- confirm that floweaver
# keeps the two sides apart as intended.
nodes = {
    'forward': flo.ProcessGroup(sources_list.tolist()),
    'backward': flo.ProcessGroup(targets_list.tolist()),
}

ordering = [
    ['forward'],       # put "forward" on the left...
    ['backward'],   # ... and "backward" on the right.
]

# a single bundle of flows going from the forward to the backward group
bundles = [
    flo.Bundle('forward', 'backward'),
]
# The first argument is the dimension name -- for now we're using
# "process" to group by process ids. The second argument is a list
# of groups.
source_part = flo.Partition.Simple('process', sources_list.tolist())

# This is another partition.
target_part = flo.Partition.Simple('process', targets_list.tolist())

# Update the ProcessGroup nodes to use the partitions
nodes['forward'].partition = source_part
nodes['backward'].partition = target_part
nodes['forward'].title = 'Forward'
nodes['backward'].title = 'Backward'

# Another partition -- but this time the dimension is the "type"
# column of the flows table
class_flow_part = flo.Partition.Simple('type', class_list)

# Set the colours for the labels in the partition.
palette = {cl : col for cl, col in zip(class_list, color_list)}

# New SDD with the flow_partition set
sdd = flo.SankeyDefinition(nodes, bundles, ordering,
                       flow_partition=class_flow_part)

# weave the flows table into the Sankey data that is serialised afterwards
w = flo.weave(sdd, df_flows, palette=palette)

#%% save flow and render svg
import json

# Output files for the slow random walk (paths are relative to the current
# working directory) ...
json_file = os.path.join('../figures/paper_figures/','schema_flow_slow_rev.json')
svg_file = os.path.join('../figures/paper_figures/','schema_flow_slow_rev.svg')

# ... immediately overridden by the fast-walk file names: as written, only the
# "fast" files are produced. Comment out this pair to write the "slow" ones.
json_file = os.path.join('../figures/paper_figures/','schema_flow_fast_rev.json')
svg_file = os.path.join('../figures/paper_figures/','schema_flow_fast_rev.svg')

class NpEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to built-in types."""

    def default(self, obj):
        """Return a JSON-serialisable equivalent of ``obj``.

        NumPy integers, floats and booleans become ``int``, ``float`` and
        ``bool``; arrays become (nested) lists. Any other object is delegated
        to ``json.JSONEncoder.default``, which raises ``TypeError``.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            # np.bool_ is not a subclass of bool, json cannot encode it as is
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
    
    
# serialise the weaved Sankey data; NpEncoder converts the NumPy values
with open(json_file, 'w') as fopen:
    json.dump(w.to_json(), fopen, cls=NpEncoder)

# Render the JSON file to SVG with the external `svg-sankey` command.
# The arguments are passed as a list (no shell involved) and stdout is
# redirected to the svg file here; check=True raises when the command fails
# instead of silently leaving an empty or partial SVG behind.
import subprocess

with open(svg_file, 'w') as svg_out:
    subprocess.run(['svg-sankey',
                    '--size', '250,200',
                    '--font-size', '10',
                    '--margins', '2,2',
                    json_file],
                   stdout=svg_out, check=True)

#remove node labels

# os.system(f"sed -i 's/>w[b-f][0-9]_[0-9][0-9]<\/text/><\/text/g' {svg_file}")


#%% static but varying the length of the first and last slices (Tab. S1)

# per slice length: adjacency matrices, clustering matrices, best partitions
# and NVI between Louvain runs, each as a list [early slice, late slice]
res_len_slices = {}

# NMI between the flow partitions (slow / fast walker) and the static ones
nmis_slow_len = {}
nmis_fast_len= {}

add_self_edges = True

# keys of `parts`: same expressions as in the loop that fills it, rather
# than the float literals 0.625 / 1.25
lamda_slow = 1/1.6
lamda_fast = 1/0.8

slice_lens = np.linspace(net.start_time, net.end_time, 11)[1:]

for slice_len in slice_lens:

    print(slice_len)

    # early slice: [0, slice_len]; late slice: [5 - slice_len, 5]
    # NOTE(review): 0 and 5 are hard-coded, presumably net.start_time and
    # net.end_time of the schematic network -- confirm.
    t_starts = [0,5-slice_len]
    t_stops = [0+slice_len,5]

    slice_res = {'adjacencies': [],
                 'modularity_matrices': [],
                 'sc_partitions': [],
                 'avg_nvi': []}
    res_len_slices[slice_len] = slice_res

    for ts,te in zip(t_starts,t_stops):
        A = net.compute_static_adjacency_matrix(start_time=ts,
                                                end_time=te).toarray()

        if add_self_edges:
            # count duration of when a node is active
            active_times = np.zeros(net.num_nodes)
            for ti,tj in zip(net.times[:-1],net.times[1:]):
                # only inter-event intervals overlapping the slice
                if ti < te and tj > ts:
                    Atemp = net.compute_static_adjacency_matrix(start_time=ti,
                                                            end_time=tj).toarray()
                    # number of neighbours for each interevent time
                    d = (Atemp>0).sum(1)
                    # add dt (clipped to the slice) only to active nodes
                    active_times[d>0] += min(tj,te)-max(ti,ts)

            # self-loop weight: slice duration minus the node's active time
            np.fill_diagonal(A,te-ts - active_times)

        slice_res['adjacencies'].append(A)

        if np.allclose(A, np.zeros_like(A)):
            # empty slice: keep the three result lists aligned
            slice_res['modularity_matrices'].append(None)
            slice_res['sc_partitions'].append(None)
            slice_res['avg_nvi'].append(None)
        else:
            sc = static_clustering(A, discrete_time_rw=True)
            _, cluster_lists, stabilities, _ = run_multi_louvain(sc, 50)

            # One value per slice: the list used to be overwritten by a
            # scalar, which lost the value of the early slice.
            avg_nvi = avg_norm_var_information(cluster_lists)
            slice_res['avg_nvi'].append(avg_nvi)

            if avg_nvi > 0:
                # report the slice length (the unrelated `num_slices` of the
                # multislice cell used to be printed here)
                print(slice_len, avg_nvi)

            slice_res['modularity_matrices'].append(sc._S)

            # keep the partition of the run with the highest stability
            slice_res['sc_partitions'].append(
                cluster_lists[np.argmax(stabilities)])

    # forward partitions are compared with the early slice, backward
    # partitions with the late slice
    early_part = slice_res['sc_partitions'][0]
    late_part = slice_res['sc_partitions'][-1]

    nmis_fast_len[slice_len] = {
        'forw_fast_nmi': norm_mutual_information(parts[lamda_fast]['forw'],
                                                 early_part),
        'back_fast_nmi': norm_mutual_information(parts[lamda_fast]['back'],
                                                 late_part),
    }
    nmis_slow_len[slice_len] = {
        'forw_slow_nmi': norm_mutual_information(parts[lamda_slow]['forw'],
                                                 early_part),
        'back_slow_nmi': norm_mutual_information(parts[lamda_slow]['back'],
                                                 late_part),
    }
    
#%% results as a table

import pandas as pd

# rows are indexed by the slice length divided by 5 (the end time hard-coded
# in this script)
rel_lens = slice_lens/5

# static partitions of the early and late slices
df_parts = pd.DataFrame([res_len_slices[sl]['sc_partitions'] for sl in slice_lens],
                        columns=['early', 'late'], index=rel_lens)

# NMI values for the fast and the slow walker, forward then backward
fast_rows = [[nmis_fast_len[sl][key] for key in ('forw_fast_nmi', 'back_fast_nmi')]
             for sl in slice_lens]
slow_rows = [[nmis_slow_len[sl][key] for key in ('forw_slow_nmi', 'back_slow_nmi')]
             for sl in slice_lens]

df_nmis_fast = pd.DataFrame(fast_rows, columns=['fast forw', 'fast back'],
                            index=rel_lens)
df_nmis_slow = pd.DataFrame(slow_rows, columns=['slow forw', 'slow back'],
                            index=rel_lens)

df = pd.concat([df_parts, df_nmis_fast, df_nmis_slow], axis=1)

print(df.to_latex(float_format="%.2f"))
